{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\"\"\"\n",
                "Please run notebook locally (if you have all the dependencies and a GPU). \n",
                "Technically you can run this notebook on Google Colab but you need to set up microphone for Colab.\n",
                "\n",
                "Instructions for setting up Colab are as follows:\n",
                "1. Open a new Python 3 notebook.\n",
                "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
                "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
                "4. Run this cell to set up dependencies.\n",
                "5. Set up microphone for Colab\n",
                "\"\"\"\n",
                "# If you're using Google Colab and not running locally, run this cell.\n",
                "\n",
                "## Install dependencies\n",
                "!pip install wget\n",
                "!apt-get install sox libsndfile1 ffmpeg portaudio19-dev\n",
                "!pip install unidecode\n",
                "!pip install pyaudio\n",
                "\n",
                "# ## Install NeMo\n",
                "BRANCH = 'main'\n",
                "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n",
                "\n",
                "## Install TorchAudio\n",
                "!pip install torchaudio>=0.10.0 -f https://download.pytorch.org/whl/torch_stable.html\n",
                "\n",
                "## Grab the config we'll use in this example\n",
                "!mkdir configs"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "This notebook demonstrates automatic speech recognition (ASR) from a microphone's stream in NeMo.\n",
                "\n",
                "It is **not a recommended** way to do inference in production workflows. If you are interested in \n",
                "production-level inference using NeMo ASR models, please refer to NVIDIA RIVA: https://developer.nvidia.com/riva"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The notebook requires PyAudio library to get a signal from an audio device.\n",
                "For Ubuntu, please run the following commands to install it:\n",
                "```\n",
                "sudo apt-get install -y portaudio19-dev\n",
                "pip install pyaudio\n",
                "```"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "This notebook requires the `torchaudio` library to be installed for MatchboxNet. Please follow the instructions available at the [torchaudio Github page](https://github.com/pytorch/audio#installation) to install the appropriate version of torchaudio.\n",
                "\n",
                "If you would like to install the latest version, please run the following command to install it:\n",
                "\n",
                "```\n",
                "conda install -c pytorch torchaudio\n",
                "```"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "import pyaudio as pa\n",
                "import os, time"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import nemo\n",
                "import nemo.collections.asr as nemo_asr"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# sample rate, Hz\n",
                "SAMPLE_RATE = 16000"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Restore the model from NGC"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained('QuartzNet15x5Base-En')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Observing the config of the model"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from omegaconf import OmegaConf\n",
                "import copy"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "scrolled": true
            },
            "outputs": [],
            "source": [
                "# Preserve a copy of the full config\n",
                "cfg = copy.deepcopy(asr_model._cfg)\n",
                "print(OmegaConf.to_yaml(cfg))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Modify preprocessor parameters for inference"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Make config overwrite-able\n",
                "OmegaConf.set_struct(cfg.preprocessor, False)\n",
                "\n",
                "# some changes for streaming scenario\n",
                "cfg.preprocessor.dither = 0.0\n",
                "cfg.preprocessor.pad_to = 0\n",
                "\n",
                "# spectrogram normalization constants\n",
                "normalization = {}\n",
                "normalization['fixed_mean'] = [\n",
                "     -14.95827016, -12.71798736, -11.76067913, -10.83311182,\n",
                "     -10.6746914,  -10.15163465, -10.05378331, -9.53918999,\n",
                "     -9.41858904,  -9.23382904,  -9.46470918,  -9.56037,\n",
                "     -9.57434245,  -9.47498732,  -9.7635205,   -10.08113074,\n",
                "     -10.05454561, -9.81112681,  -9.68673603,  -9.83652977,\n",
                "     -9.90046248,  -9.85404766,  -9.92560366,  -9.95440354,\n",
                "     -10.17162966, -9.90102482,  -9.47471025,  -9.54416855,\n",
                "     -10.07109475, -9.98249912,  -9.74359465,  -9.55632283,\n",
                "     -9.23399915,  -9.36487649,  -9.81791084,  -9.56799225,\n",
                "     -9.70630899,  -9.85148006,  -9.8594418,   -10.01378735,\n",
                "     -9.98505315,  -9.62016094,  -10.342285,   -10.41070709,\n",
                "     -10.10687659, -10.14536695, -10.30828702, -10.23542833,\n",
                "     -10.88546868, -11.31723646, -11.46087382, -11.54877829,\n",
                "     -11.62400934, -11.92190509, -12.14063815, -11.65130117,\n",
                "     -11.58308531, -12.22214663, -12.42927197, -12.58039805,\n",
                "     -13.10098969, -13.14345864, -13.31835645, -14.47345634]\n",
                "normalization['fixed_std'] = [\n",
                "     3.81402054, 4.12647781, 4.05007065, 3.87790987,\n",
                "     3.74721178, 3.68377423, 3.69344,    3.54001005,\n",
                "     3.59530412, 3.63752368, 3.62826417, 3.56488469,\n",
                "     3.53740577, 3.68313898, 3.67138151, 3.55707266,\n",
                "     3.54919572, 3.55721289, 3.56723346, 3.46029304,\n",
                "     3.44119672, 3.49030548, 3.39328435, 3.28244406,\n",
                "     3.28001423, 3.26744937, 3.46692348, 3.35378948,\n",
                "     2.96330901, 2.97663111, 3.04575148, 2.89717604,\n",
                "     2.95659301, 2.90181116, 2.7111687,  2.93041291,\n",
                "     2.86647897, 2.73473181, 2.71495654, 2.75543763,\n",
                "     2.79174615, 2.96076456, 2.57376336, 2.68789782,\n",
                "     2.90930817, 2.90412004, 2.76187531, 2.89905006,\n",
                "     2.65896173, 2.81032176, 2.87769857, 2.84665271,\n",
                "     2.80863137, 2.80707634, 2.83752184, 3.01914511,\n",
                "     2.92046439, 2.78461139, 2.90034605, 2.94599508,\n",
                "     2.99099718, 3.0167554,  3.04649716, 2.94116777]\n",
                "\n",
                "cfg.preprocessor.normalize = normalization\n",
                "\n",
                "# Disable config overwriting\n",
                "OmegaConf.set_struct(cfg.preprocessor, True)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Setup preprocessor with these settings"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "asr_model.preprocessor = asr_model.from_config_dict(cfg.preprocessor)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Set model to inference mode\n",
                "asr_model.eval();"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "asr_model = asr_model.to(asr_model.device)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Setting up data for Streaming Inference"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from nemo.core.classes import IterableDataset\n",
                "from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType\n",
                "import torch\n",
                "from torch.utils.data import DataLoader"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# simple data layer to pass audio signal\n",
                "class AudioDataLayer(IterableDataset):\n",
                "    @property\n",
                "    def output_types(self):\n",
                "        return {\n",
                "            'audio_signal': NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),\n",
                "            'a_sig_length': NeuralType(tuple('B'), LengthsType()),\n",
                "        }\n",
                "\n",
                "    def __init__(self, sample_rate):\n",
                "        super().__init__()\n",
                "        self._sample_rate = sample_rate\n",
                "        self.output = True\n",
                "        \n",
                "    def __iter__(self):\n",
                "        return self\n",
                "    \n",
                "    def __next__(self):\n",
                "        if not self.output:\n",
                "            raise StopIteration\n",
                "        self.output = False\n",
                "        return torch.as_tensor(self.signal, dtype=torch.float32), \\\n",
                "               torch.as_tensor(self.signal_shape, dtype=torch.int64)\n",
                "        \n",
                "    def set_signal(self, signal):\n",
                "        self.signal = signal.astype(np.float32)/32768.\n",
                "        self.signal_shape = self.signal.size\n",
                "        self.output = True\n",
                "\n",
                "    def __len__(self):\n",
                "        return 1"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "data_layer = AudioDataLayer(sample_rate=cfg.preprocessor.sample_rate)\n",
                "data_loader = DataLoader(data_layer, batch_size=1, collate_fn=data_layer.collate_fn)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# inference method for audio signal (single instance)\n",
                "def infer_signal(model, signal):\n",
                "    data_layer.set_signal(signal)\n",
                "    batch = next(iter(data_loader))\n",
                "    audio_signal, audio_signal_len = batch\n",
                "    audio_signal, audio_signal_len = audio_signal.to(asr_model.device), audio_signal_len.to(asr_model.device)\n",
                "    log_probs, encoded_len, predictions = model.forward(\n",
                "        input_signal=audio_signal, input_signal_length=audio_signal_len\n",
                "    )\n",
                "    return log_probs"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# class for streaming frame-based ASR\n",
                "# 1) use reset() method to reset FrameASR's state\n",
                "# 2) call transcribe(frame) to do ASR on\n",
                "#    contiguous signal's frames\n",
                "class FrameASR:\n",
                "    \n",
                "    def __init__(self, model_definition,\n",
                "                 frame_len=2, frame_overlap=2.5, \n",
                "                 offset=10):\n",
                "        '''\n",
                "        Args:\n",
                "          frame_len: frame's duration, seconds\n",
                "          frame_overlap: duration of overlaps before and after current frame, seconds\n",
                "          offset: number of symbols to drop for smooth streaming\n",
                "        '''\n",
                "        self.vocab = list(model_definition['labels'])\n",
                "        self.vocab.append('_')\n",
                "        \n",
                "        self.sr = model_definition['sample_rate']\n",
                "        self.frame_len = frame_len\n",
                "        self.n_frame_len = int(frame_len * self.sr)\n",
                "        self.frame_overlap = frame_overlap\n",
                "        self.n_frame_overlap = int(frame_overlap * self.sr)\n",
                "        timestep_duration = model_definition['AudioToMelSpectrogramPreprocessor']['window_stride']\n",
                "        for block in model_definition['JasperEncoder']['jasper']:\n",
                "            timestep_duration *= block['stride'][0] ** block['repeat']\n",
                "        self.n_timesteps_overlap = int(frame_overlap / timestep_duration) - 2\n",
                "        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len,\n",
                "                               dtype=np.float32)\n",
                "        self.offset = offset\n",
                "        self.reset()\n",
                "        \n",
                "    def _decode(self, frame, offset=0):\n",
                "        assert len(frame)==self.n_frame_len\n",
                "        self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]\n",
                "        self.buffer[-self.n_frame_len:] = frame\n",
                "        logits = infer_signal(asr_model, self.buffer).cpu().numpy()[0]\n",
                "        # print(logits.shape)\n",
                "        decoded = self._greedy_decoder(\n",
                "            logits[self.n_timesteps_overlap:-self.n_timesteps_overlap], \n",
                "            self.vocab\n",
                "        )\n",
                "        return decoded[:len(decoded)-offset]\n",
                "    \n",
                "    @torch.no_grad()\n",
                "    def transcribe(self, frame=None, merge=True):\n",
                "        if frame is None:\n",
                "            frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)\n",
                "        if len(frame) < self.n_frame_len:\n",
                "            frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')\n",
                "        unmerged = self._decode(frame, self.offset)\n",
                "        if not merge:\n",
                "            return unmerged\n",
                "        return self.greedy_merge(unmerged)\n",
                "    \n",
                "    def reset(self):\n",
                "        '''\n",
                "        Reset frame_history and decoder's state\n",
                "        '''\n",
                "        self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)\n",
                "        self.prev_char = ''\n",
                "\n",
                "    @staticmethod\n",
                "    def _greedy_decoder(logits, vocab):\n",
                "        s = ''\n",
                "        for i in range(logits.shape[0]):\n",
                "            s += vocab[np.argmax(logits[i])]\n",
                "        return s\n",
                "\n",
                "    def greedy_merge(self, s):\n",
                "        s_merged = ''\n",
                "        \n",
                "        for i in range(len(s)):\n",
                "            if s[i] != self.prev_char:\n",
                "                self.prev_char = s[i]\n",
                "                if self.prev_char != '_':\n",
                "                    s_merged += self.prev_char\n",
                "        return s_merged"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Streaming Inference\n",
                "\n",
                "Streaming inference depends on a few factors, such as the frame length and buffer size. Experiment with a few values to see their effects in the below cells."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# duration of signal frame, seconds\n",
                "FRAME_LEN = 1.0\n",
                "# number of audio channels (expect mono signal)\n",
                "CHANNELS = 1\n",
                "\n",
                "CHUNK_SIZE = int(FRAME_LEN*SAMPLE_RATE)\n",
                "asr = FrameASR(model_definition = {\n",
                "                   'sample_rate': SAMPLE_RATE,\n",
                "                   'AudioToMelSpectrogramPreprocessor': cfg.preprocessor,\n",
                "                   'JasperEncoder': cfg.encoder,\n",
                "                   'labels': cfg.decoder.vocabulary\n",
                "               },\n",
                "               frame_len=FRAME_LEN, frame_overlap=2, \n",
                "               offset=4)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "asr.reset()\n",
                "\n",
                "p = pa.PyAudio()\n",
                "print('Available audio input devices:')\n",
                "input_devices = []\n",
                "for i in range(p.get_device_count()):\n",
                "    dev = p.get_device_info_by_index(i)\n",
                "    if dev.get('maxInputChannels'):\n",
                "        input_devices.append(i)\n",
                "        print(i, dev.get('name'))\n",
                "\n",
                "if len(input_devices):\n",
                "    dev_idx = -2\n",
                "    while dev_idx not in input_devices:\n",
                "        print('Please type input device ID:')\n",
                "        dev_idx = int(input())\n",
                "\n",
                "    empty_counter = 0\n",
                "\n",
                "    def callback(in_data, frame_count, time_info, status):\n",
                "        global empty_counter\n",
                "        signal = np.frombuffer(in_data, dtype=np.int16)\n",
                "        text = asr.transcribe(signal)\n",
                "        if len(text):\n",
                "            print(text,end='')\n",
                "            empty_counter = asr.offset\n",
                "        elif empty_counter > 0:\n",
                "            empty_counter -= 1\n",
                "            if empty_counter == 0:\n",
                "                print(' ',end='')\n",
                "        return (in_data, pa.paContinue)\n",
                "\n",
                "    stream = p.open(format=pa.paInt16,\n",
                "                    channels=CHANNELS,\n",
                "                    rate=SAMPLE_RATE,\n",
                "                    input=True,\n",
                "                    input_device_index=dev_idx,\n",
                "                    stream_callback=callback,\n",
                "                    frames_per_buffer=CHUNK_SIZE)\n",
                "\n",
                "    print('Listening...')\n",
                "\n",
                "    stream.start_stream()\n",
                "    \n",
                "    # Interrupt kernel and then speak for a few more words to exit the pyaudio loop !\n",
                "    try:\n",
                "        while stream.is_active():\n",
                "            time.sleep(0.1)\n",
                "    finally:        \n",
                "        stream.stop_stream()\n",
                "        stream.close()\n",
                "        p.terminate()\n",
                "\n",
                "        print()\n",
                "        print(\"PyAudio stopped\")\n",
                "    \n",
                "else:\n",
                "    print('ERROR: No audio input device found.')"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.7.7"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 4
}